import logging
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Dict, List, Literal, Optional, Tuple

import yaml
from mpi4py.MPI import COMM_WORLD, Comm

from .._utils import global_mpi_rank, global_mpi_size

# Public API for `from ... import *`. Every listed name must exist in this
# module, otherwise star-imports fail with AttributeError.
__all__ = [
    'CtxGenServerConfig',
    'parse_disagg_config_file',
    'extract_disagg_cfg',
    'split_world_comm',
]


class ServerRole(IntEnum):
    """Integer-valued role of a server in disaggregated serving."""

    CONTEXT = 0  # context-phase servers ("ctx" sections)
    GENERATION = 1  # generation-phase servers ("gen" sections)
    MM_ENCODER = 2  # presumably multimodal-encoder servers -- confirm with users


@dataclass
class CtxGenServerConfig:
    """Description of one context ("ctx") or generation ("gen") server instance."""

    # Which phase this instance serves.
    type: Literal['ctx', 'gen']
    # Listening address; left as None when no URL was configured.
    hostname: Optional[str] = None
    port: Optional[int] = None
    # Number of MPI ranks the instance occupies.
    instance_num_ranks: int = 1
    # Remaining per-server options, kept as given.
    other_args: dict = field(default_factory=dict)


@dataclass
class RouterConfig():
    """Router choice for one server pool.

    ``type`` names the routing policy and ``args`` holds its extra options.
    ``server_role`` is None until the caller assigns it (extract_disagg_cfg
    sets it right after construction).
    """
    type: str = "round_robin"
    args: dict = field(default_factory=dict)
    # Optional[...] because the default is None; the old bare `ServerRole`
    # annotation was wrong for the default value.
    server_role: Optional[ServerRole] = None


@dataclass
class ConditionalDisaggConfig:
    """Settings for conditional disaggregation."""

    # Presumably prompts up to this length are prefilled locally rather than
    # disaggregated -- confirm with the consumer. Defaults to 0.
    max_local_prefill_length: int = 0


@dataclass
class OtlpConfig:
    """OpenTelemetry (OTLP) export settings."""

    # Target URL to which OpenTelemetry traces will be sent; None disables it.
    otlp_traces_endpoint: Optional[str] = None


@dataclass
class MinimalInstances:
    """Lower bounds on the number of server instances in a cluster."""

    # Minimal number of context servers.
    context_servers: int = 1
    # Minimal number of generation servers.
    generation_servers: int = 1


@dataclass
class DisaggClusterConfig:
    """Settings for a cluster whose workers register through cluster storage."""

    # URI of the cluster storage (required).
    cluster_uri: str
    # Cluster name, used like a namespace inside the storage.
    cluster_name: str = ""
    # Minimal instance counts; None when not configured.
    minimal_instances: Optional[MinimalInstances] = None
    # Workers send a heartbeat to the cluster storage every
    # heartbeat_interval_sec seconds ...
    heartbeat_interval_sec: int = 5
    # ... and are considered inactive after inactive_timeout_sec seconds
    # without one.
    inactive_timeout_sec: int = 10


@dataclass
class DisaggServerConfig:
    """Complete configuration of the disaggregated server.

    Keep the field order stable: instances may be built positionally.
    """

    # Static ctx/gen instances; empty when a disagg cluster config is used.
    server_configs: List[CtxGenServerConfig]
    # Address of the disaggregated server itself.
    hostname: str = "localhost"
    port: int = 8000
    # Routers for the context and generation pools.
    ctx_router_config: Optional[RouterConfig] = None
    gen_router_config: Optional[RouterConfig] = None
    # Optional feature sections.
    conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
    otlp_config: Optional[OtlpConfig] = None
    max_retries: int = 1
    perf_metrics_max_requests: int = 0
    # Cluster-storage settings; None when servers are listed statically.
    disagg_cluster_config: Optional[DisaggClusterConfig] = None


@dataclass
class MetadataServerConfig:
    """Connection settings for the metadata server (only 'etcd' is accepted by the type)."""

    server_type: Literal['etcd']
    hostname: str = "localhost"
    port: int = 2379
    # Presumably in seconds -- confirm against the consumer.
    health_check_timeout: float = 5.0
    refresh_interval: float = 10.0


def get_ctx_gen_server_addrs(
        server_configs: list[CtxGenServerConfig]
) -> tuple[list[str], list[str]]:
    """Split server configs into "hostname:port" strings per role.

    Configs whose ``type`` is "ctx" go to the first list; every other config
    goes to the second (generation) list. Input order is preserved.

    Returns:
        ``(ctx_addrs, gen_addrs)``.
    """
    ctx_addrs: list[str] = []
    gen_addrs: list[str] = []
    for server in server_configs:
        target = ctx_addrs if server.type == "ctx" else gen_addrs
        target.append(f"{server.hostname}:{server.port}")
    return ctx_addrs, gen_addrs


def parse_disagg_config_file(yaml_config_file: str):
    """Load a disaggregated-serving YAML file and build its DisaggServerConfig.

    The file's top-level keys are passed as keyword arguments to
    ``extract_disagg_cfg``.

    Args:
        yaml_config_file: Path to the YAML config file.

    Returns:
        The DisaggServerConfig built by ``extract_disagg_cfg``.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping.
    """
    with open(yaml_config_file, 'r') as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document (and a list/scalar for
    # other top-level shapes). `**config` would then fail with an opaque
    # TypeError, so report the actual problem instead.
    if not isinstance(config, dict):
        raise ValueError(
            f"Disaggregated config file {yaml_config_file} must contain a "
            f"YAML mapping at the top level, got {type(config).__name__}")

    return extract_disagg_cfg(**config)


def extract_disagg_cfg(hostname: str = 'localhost',
                       port: int = 8000,
                       max_retries: int = 1,
                       perf_metrics_max_requests: int = 0,
                       context_servers: Optional[dict] = None,
                       generation_servers: Optional[dict] = None,
                       conditional_disagg_config: Optional[dict] = None,
                       otlp_config: Optional[dict] = None,
                       disagg_cluster: Optional[dict] = None,
                       **kwargs: Any) -> DisaggServerConfig:
    """Build a DisaggServerConfig from the top-level keys of a disagg config.

    Any extra keyword argument (a top-level key that is not one of the named
    parameters) is copied into both the ``context_servers`` and the
    ``generation_servers`` sections. If a section already sets that key, the
    two values must be equal.

    The caller's dicts are not modified.

    Args:
        hostname, port: Address stored in the result.
        max_retries, perf_metrics_max_requests: Copied into the result.
        context_servers, generation_servers: Per-phase sections. The optional
            ``router`` entry of each becomes that phase's RouterConfig; the
            remaining keys are forwarded to ``extract_ctx_gen_cfgs``.
        conditional_disagg_config: Fields for ConditionalDisaggConfig.
        otlp_config: Fields for OtlpConfig.
        disagg_cluster: Fields for DisaggClusterConfig. When given, no static
            server configs are built (``server_configs`` is empty).
        **kwargs: Top-level keys inherited by both sections.

    Returns:
        The assembled DisaggServerConfig.

    Raises:
        ValueError: if an inherited key conflicts with a section's own value
            (other errors may come from the helper extractors).
    """
    # Work on shallow copies: the inheritance loop below adds keys, and
    # extract_router_config() pops "router" from the section. Previously this
    # mutated the caller's dicts, so a second call with the same dicts lost
    # the router configuration.
    context_servers = dict(context_servers or {})
    generation_servers = dict(generation_servers or {})

    # If parameters are specified outside the context_servers and
    # generation_servers sections, make sure they match; otherwise inherit
    # the top-level value.
    for key, value in kwargs.items():
        for server_type, servers in (("context_servers", context_servers),
                                     ("generation_servers",
                                      generation_servers)):
            if key not in servers:
                servers[key] = value
            elif servers[key] != value:
                raise ValueError(
                    f"Parameter {key} is specified both in the top-level and in the {server_type} section, but with different values"
                )

    # Give each section a private router dict. It may belong to the caller,
    # or be shared by both sections when `router` was inherited from the top
    # level, and extract_router_config() may pop entries from it.
    for servers in (context_servers, generation_servers):
        if isinstance(servers.get("router"), dict):
            servers["router"] = dict(servers["router"])

    # Routers are extracted before the server configs because this removes
    # "router" from each section, keeping it out of the servers' other_args.
    ctx_router_config = extract_router_config(context_servers)
    gen_router_config = extract_router_config(generation_servers)
    ctx_router_config.server_role = ServerRole.CONTEXT
    gen_router_config.server_role = ServerRole.GENERATION

    server_configs = []
    disagg_cluster_config = None
    if disagg_cluster:
        # Cluster mode: servers are not listed statically here.
        disagg_cluster_config = extract_disagg_cluster_config(disagg_cluster)
    else:
        server_configs = (
            extract_ctx_gen_cfgs(type="ctx", **context_servers) +
            extract_ctx_gen_cfgs(type="gen", **generation_servers))

    conditional_disagg_config = ConditionalDisaggConfig(
        **conditional_disagg_config) if conditional_disagg_config else None

    otlp_config = OtlpConfig(**otlp_config) if otlp_config else None

    return DisaggServerConfig(
        server_configs=server_configs,
        hostname=hostname,
        port=port,
        ctx_router_config=ctx_router_config,
        gen_router_config=gen_router_config,
        conditional_disagg_config=conditional_disagg_config,
        otlp_config=otlp_config,
        max_retries=max_retries,
        perf_metrics_max_requests=perf_metrics_max_requests,
        disagg_cluster_config=disagg_cluster_config,
    )


def extract_ctx_gen_cfgs(type: Literal['ctx', 'gen'],
                         num_instances: int = 1,
                         urls: Optional[List[str]] = None,
                         **kwargs: Any) -> List[CtxGenServerConfig]:
    """Build one CtxGenServerConfig per instance of a ctx or gen section.

    Args:
        type: Role of every config produced ('ctx' or 'gen').
        num_instances: Number of instances. Must equal ``len(urls)`` when
            ``urls`` is given.
        urls: Optional ``"hostname:port"`` strings, one per instance. When
            omitted, hostname and port are left as None.
        **kwargs: Remaining section options. ``tensor_parallel_size``,
            ``pipeline_parallel_size`` and ``context_parallel_size`` (each
            defaulting to 1) are multiplied to get the ranks per instance.
            The dict itself is stored as ``other_args`` and is shared by all
            returned configs.

    Returns:
        The configs, in ``urls`` order.

    Raises:
        ValueError: if a URL is not ``hostname:port`` with an integer port,
            or the number of URLs differs from ``num_instances``.
    """
    if urls:
        hostnames = []
        ports = []
        for url in urls:
            parts = url.split(':')
            if len(parts) != 2:
                raise ValueError(
                    f"Invalid url {url!r}: expected 'hostname:port'")
            hostname, port_str = parts
            try:
                port = int(port_str)
            except ValueError as err:
                raise ValueError(
                    f"Invalid port {port_str!r} in url {url!r}") from err
            hostnames.append(hostname)
            ports.append(port)

        # hostnames and ports always get one entry per url, so one length
        # check covers both. The former separate ports check could never fire.
        if len(hostnames) != num_instances:
            raise ValueError(
                f"Number of hostnames ({len(hostnames)}) should be equal to the number of instances ({num_instances})"
            )
    else:
        hostnames = [None] * num_instances
        ports = [None] * num_instances

    # Compute the number of ranks per instance
    instance_num_ranks = kwargs.get('tensor_parallel_size', 1) * kwargs.get(
        'pipeline_parallel_size', 1) * kwargs.get('context_parallel_size', 1)

    return [
        CtxGenServerConfig(type=type,
                           hostname=hostname,
                           port=port,
                           instance_num_ranks=instance_num_ranks,
                           other_args=kwargs)
        for hostname, port in zip(hostnames, ports)
    ]


def extract_router_config(server_cfg: dict) -> RouterConfig:
    """Build a RouterConfig from a server section's ``router`` entry.

    The ``router`` key is removed from ``server_cfg`` so it is not forwarded
    as a server option. The router's ``type`` (default "round_robin")
    selects the policy; its other keys become ``args``. ``max_batch_size``
    and ``max_num_tokens`` are copied from the section into ``args`` when
    present there.

    Args:
        server_cfg: A context/generation server section (modified in place:
            ``router`` is popped).

    Returns:
        The RouterConfig. ``server_role`` is left for the caller to set.
    """
    # Copy the router dict before popping/adding keys. The same dict can be
    # shared by the ctx and gen sections (top-level `router:` is inherited by
    # both), and mutating it would drop "type" for the second section.
    # `or {}` also tolerates an explicit `router: null` in YAML.
    args = dict(server_cfg.pop("router", None) or {})
    router_type = args.pop("type", "round_robin")

    # add fields that are not specific to router
    for key in ("max_batch_size", "max_num_tokens"):
        if key in server_cfg:
            args[key] = server_cfg[key]

    return RouterConfig(type=router_type, args=args)


def get_server_configs_dict(
        server_configs: List[CtxGenServerConfig]) -> Tuple[int, dict]:
    """Deduplicate server configs by address and count the ranks they need.

    A context and a generation config may share one ``(hostname, port)`` (a
    "mixed" server) if their ``other_args`` are equal. Such a server is
    counted once.

    Returns:
        ``(num_workers, server_dict)``: the summed ``instance_num_ranks`` of
        the distinct addresses, and a dict mapping ``(hostname, port)`` to
        the first config seen for that address.

    Raises:
        ValueError: on two configs of the same type at one address, or on a
            mixed server whose configs have different ``other_args``.
    """
    by_addr: dict = {}
    for cfg in server_configs:
        addr = (cfg.hostname, cfg.port)
        if addr not in by_addr:
            by_addr[addr] = cfg
            continue

        first = by_addr[addr]
        if first.type == cfg.type:
            raise ValueError(
                f"Duplicated {cfg.type} server config for {addr}")
        # mixed server, config should be the same
        if first.other_args != cfg.other_args:
            raise ValueError(
                f"Server config for {addr} has different args:\n{first.other_args}\n{cfg.other_args}"
            )

    num_workers = sum(cfg.instance_num_ranks for cfg in by_addr.values())
    return num_workers, by_addr


def extract_disagg_cluster_config(
        cluster_config_dict: Dict[str, Any],
        cluster_uri: Optional[str] = None) -> DisaggClusterConfig:
    """
    Build the DisaggClusterConfig from the cluster_config_dict.

    Use the default value of DisaggClusterConfig and MinimalInstances if the
    corresponding fields are not provided (or are None).
    If cluster_uri is provided, it overrides the cluster_uri in the
    cluster_config_dict. The input dict is not modified.

    Raises:
        KeyError: if a key is not a field of the target dataclass, or if no
            cluster_uri is available from either the argument or the dict.
    """

    def update_dataclass(obj, data_dict: Dict[str, Any]):
        # Copy every non-None entry onto obj; unknown keys are rejected.
        for key, value in data_dict.items():
            if key not in obj.__dataclass_fields__:
                raise KeyError(
                    f"Key {key} not found in {obj.__class__.__name__}")
            if value is not None:
                setattr(obj, key, value)
        return obj

    # Work on a copy. Writing the MinimalInstances object back into the
    # caller's dict made a second call on the same dict fail on `.items()`.
    config_dict = dict(cluster_config_dict)
    # `or {}` also covers an explicit `minimal_instances: null`.
    config_dict["minimal_instances"] = update_dataclass(
        MinimalInstances(), config_dict.get("minimal_instances") or {})

    effective_uri = cluster_uri or config_dict["cluster_uri"]
    # Drop the dict's cluster_uri so update_dataclass() cannot overwrite the
    # explicit override (previously the dict value always won).
    config_dict.pop("cluster_uri", None)

    return update_dataclass(DisaggClusterConfig(effective_uri), config_dict)


def split_world_comm(
        server_configs: List[CtxGenServerConfig]) -> Tuple[bool, int, Comm]:
    """Split MPI_COMM_WORLD into one sub-communicator per server instance.

    Ranks are assigned to instances contiguously, in the order the distinct
    addresses first appear in ``server_configs``: the first
    ``instance_num_ranks`` ranks form the first instance, and so on. A
    context and a generation config that share one ``(hostname, port)`` are
    a single instance. Every rank of MPI_COMM_WORLD must call this, because
    ``Split`` is collective.

    Args:
        server_configs: Per-instance configs (e.g.
            ``DisaggServerConfig.server_configs``).

    Returns:
        ``(is_leader, instance_idx, sub_comm)``. ``is_leader`` tells whether
        this rank is the first rank of its instance. ``instance_idx`` is the
        index into ``server_configs`` of this rank's instance. ``sub_comm``
        is the instance's communicator.

    Raises:
        ValueError: if the MPI world size differs from the total number of
            ranks needed by the distinct instances, or (from
            ``get_server_configs_dict``) on duplicated/inconsistent configs.
        RuntimeError: if a rank's position in ``sub_comm`` differs from its
            expected position within the instance.
    """
    global_size = global_mpi_size()
    global_rank = global_mpi_rank()

    num_workers, server_dict = get_server_configs_dict(server_configs)
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let a mismatched launch continue silently.
    if global_size != num_workers:
        raise ValueError(
            f"global_size ({global_size}) should be equal to the number of distinct workers ({num_workers})"
        )

    # Identify the leader ranks and the instance idx for each rank
    is_leader = False
    offset = 0
    instance_idx = 0
    instance_sub_rank = 0
    for idx, cfg in enumerate(server_configs):
        # Skip the second config of a mixed ctx/gen server: its address was
        # already consumed (popped) by the first one.
        if (cfg.hostname, cfg.port) not in server_dict:
            continue
        server_dict.pop((cfg.hostname, cfg.port))
        if offset <= global_rank < offset + cfg.instance_num_ranks:
            instance_idx = idx
            instance_sub_rank = global_rank - offset
            # The first rank in each instance is the leader
            if global_rank == offset:
                is_leader = True
        offset += cfg.instance_num_ranks

    # Split MPI_COMM_WORLD into sub-communicators based on instance_idx;
    # key=instance_sub_rank orders ranks within each sub-communicator.
    sub_comm = COMM_WORLD.Split(color=instance_idx, key=instance_sub_rank)
    sub_rank = sub_comm.Get_rank()
    if sub_rank != instance_sub_rank:
        raise RuntimeError(
            f"Expected sub_rank {sub_rank} to be equal to instance_sub_rank {instance_sub_rank}"
        )

    sub_comm.Barrier()

    logging.info(
        "global_rank: %s, instance_idx: %s, sub_rank: %s, is_leader: %s",
        global_rank, instance_idx, sub_rank, is_leader)

    return is_leader, instance_idx, sub_comm


def parse_metadata_server_config_file(
    metadata_server_config_file: Optional[str]
) -> Optional[MetadataServerConfig]:
    """Load a metadata-server YAML file into a MetadataServerConfig.

    Args:
        metadata_server_config_file: Path to the YAML file, or None.

    Returns:
        None when no path is given; otherwise the config built from the
        file's top-level keys.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping.
    """
    if metadata_server_config_file is None:
        return None

    with open(metadata_server_config_file, 'r') as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document, which `**config` would
    # turn into an opaque TypeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Metadata server config file {metadata_server_config_file} must "
            f"contain a YAML mapping at the top level, got {type(config).__name__}")

    return MetadataServerConfig(**config)
